-
Notifications
You must be signed in to change notification settings - Fork 1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
graph: backend: dnnl: support select with binary primitive #2349
base: main
Are you sure you want to change the base?
Conversation
6be21c9
to
4ea4e67
Compare
make test |
4ea4e67
to
b94bafa
Compare
make test |
b94bafa
to
458e748
Compare
make test |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you have any performance data to share?
@@ -2266,7 +2266,8 @@ status_t binary_canonicalization(std::shared_ptr<subgraph_t> &sg) { | |||
int32_t src1_ndims = src1_lt.ndims; | |||
int32_t target_ndims = std::max(src0_ndims, src1_ndims); | |||
std::vector<int32_t> in_ndims {src0_ndims, src1_ndims}; | |||
for (size_t i = 0; i < cur_op->num_inputs(); ++i) { | |||
std::vector<size_t> input_indices = {0, 1}; | |||
for (auto i : input_indices) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this correct? Previously num_inputs()
is 2 - 32 per the schema definition. Now the code only handles the first two?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This pass is applied before postop fusion pass, so input number is always 2 before. For this PR, although binary select has three inputs, since cond
dims has been promised to be the same that of src0
by pass decompose_select_to_binary_ops
, we only need to unsqueeze src0
and src1
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if cond
dims has been promised to be the same of src0
, then it should fall into the condition of if (in_ndims[i] == target_ndims) { continue; }
, so no unsqueeze inserted. If this is the case, no need to limit the input_indices?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in_ndims
only has two elements, the access for the third element is not legal.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok, then it seems the original code is designed for 2 elements
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This pass is applied before postop fusion pass, so input number is always 2 before. For this PR, although binary select has three inputs, since
cond
dims has been promised to be the same that ofsrc0
by passdecompose_select_to_binary_ops
, we only need to unsqueezesrc0
andsrc1
.
This explanation looks suspicious as the code has quite a few assumption to work properly. You may need to at least add comment for that.
BTW: I feel for (size_t i : {0, 1}) { .... }
should work without defining input_indices
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed: I keep the original for loop and make it skip the unsqueeze process when iterating the third input.
Sure, I have attached it to the PR description. |
458e748
to
f8262e0
Compare
make test |
f8262e0
to
325bca9
Compare
make test |
@@ -2266,7 +2266,8 @@ status_t binary_canonicalization(std::shared_ptr<subgraph_t> &sg) { | |||
int32_t src1_ndims = src1_lt.ndims; | |||
int32_t target_ndims = std::max(src0_ndims, src1_ndims); | |||
std::vector<int32_t> in_ndims {src0_ndims, src1_ndims}; | |||
for (size_t i = 0; i < cur_op->num_inputs(); ++i) { | |||
std::vector<size_t> input_indices = {0, 1}; | |||
for (auto i : input_indices) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok, then it seems the original code is designed for 2 elements
325bca9
to
6694b8c
Compare
make test |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please separate benchdnn inputs changes into a standalone commit.
@@ -2266,7 +2266,8 @@ status_t binary_canonicalization(std::shared_ptr<subgraph_t> &sg) { | |||
int32_t src1_ndims = src1_lt.ndims; | |||
int32_t target_ndims = std::max(src0_ndims, src1_ndims); | |||
std::vector<int32_t> in_ndims {src0_ndims, src1_ndims}; | |||
for (size_t i = 0; i < cur_op->num_inputs(); ++i) { | |||
std::vector<size_t> input_indices = {0, 1}; | |||
for (auto i : input_indices) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This pass is applied before postop fusion pass, so input number is always 2 before. For this PR, although binary select has three inputs, since
cond
dims has been promised to be the same that ofsrc0
by passdecompose_select_to_binary_ops
, we only need to unsqueezesrc0
andsrc1
.
This explanation looks suspicious as the code has quite a few assumption to work properly. You may need to at least add comment for that.
BTW: I feel for (size_t i : {0, 1}) { .... }
should work without defining input_indices
.
6694b8c
to
66e2b1f
Compare
Description
cond
input is defined for dnnl binary opcond
input, we use binary select primitive for non-broadcast case only, the lowering logic is: always lower select to binary primitive and then decide which impl path to use in passdecompose_select_to_multiple_binary_ops
and decompose it to multiple binary ops if necessary.Performance
relative perf:
platform: Intel(R) Xeon(R) Platinum 8490H